#!/usr/bin/env python2
# Copyright (c) 2015-2016, NVIDIA CORPORATION.  All rights reserved.

"""
Classify an image using individual model files

Use this script as an example to build your own tool
"""

import argparse
import os
import time
import sys
import glob
import cv2

from google.protobuf import text_format
import numpy as np
import PIL.Image
import scipy.misc
import datetime
import pandas as pd

# from caffe_test import create_submission, merge_several_folds_mean, load_test

os.environ['GLOG_minloglevel'] = '2'  # Suppress most caffe output
sys.path.append('/home/dell/caffe-master/python')
import caffe
from caffe.proto import caffe_pb2

#data_root = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection'
data_root='/media/dell/delldisk/dell/Rui/data/Kaggle/DistractedDriverDetection'

def merge_several_folds_mean(data, nfolds):
    """
    Element-wise mean of the first `nfolds` entries of `data`.

    Arguments:
    data -- sequence of equal-shaped array-likes (per-fold predictions)
    nfolds -- number of leading folds to average

    Returns the averaged values as a (nested) list of floats.
    """
    # Accumulate in float64 so integer inputs are not floor-divided
    # (the original's in-place `a /= nfolds` truncated int arrays on
    # Python 2 and raises a casting error on Python 3).
    stacked = np.asarray(data[:nfolds], dtype=np.float64)
    return (stacked.sum(axis=0) / nfolds).tolist()


def create_submission(predictions, test_id, info):
    """
    Write class probabilities plus image ids to a timestamped CSV file
    under <data_root>/subm, creating that directory on first use.

    Arguments:
    predictions -- array-like of shape (nImages, 10), one row per image
    test_id -- image file names, aligned with the rows of predictions
    info -- tag embedded in the submission file name
    """
    class_cols = ['c%d' % i for i in range(10)]
    frame = pd.DataFrame(predictions, columns=class_cols)
    frame.loc[:, 'img'] = pd.Series(test_id, index=frame.index)

    if not os.path.isdir(data_root + '/subm'):
        os.mkdir(data_root + '/subm')

    stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
    suffix = info + '_' + str(stamp)
    out_path = os.path.join(data_root, 'subm', 'submission_' + suffix + '.csv')
    frame.to_csv(out_path, index=False)


def get_net(caffemodel, deploy_file, use_gpu=True):
    """
    Build a caffe.Net in TEST phase from a deploy prototxt plus weights.

    Arguments:
    caffemodel -- path to a .caffemodel file
    deploy_file -- path to a .prototxt file

    Keyword arguments:
    use_gpu -- if True, use the GPU for inference
    """
    if use_gpu:
        caffe.set_mode_gpu()
    # Instantiate the network with the trained weights in test mode.
    network = caffe.Net(deploy_file, caffemodel, caffe.TEST)
    return network


def get_transformer(deploy_file, mean_file=None):
    """
    Build a caffe.io.Transformer matching the network's input geometry.

    Arguments:
    deploy_file -- path to a .prototxt file

    Keyword arguments:
    mean_file -- path to a .binaryproto file (optional); when given, its
        per-channel mean is subtracted from every input
    """
    net_param = caffe_pb2.NetParameter()
    with open(deploy_file) as proto:
        text_format.Merge(proto.read(), net_param)

    # Input geometry: prefer the input_shape message, fall back to the
    # legacy four input_dim fields.
    if net_param.input_shape:
        dims = net_param.input_shape[0].dim
    else:
        dims = net_param.input_dim[:4]

    transformer = caffe.io.Transformer(inputs={'data': dims})
    # HxWxC -> CxHxW. No BGR<->RGB channel swap is configured here: the
    # images are loaded with cv2 (see load_image), so they are presumably
    # already in the BGR order this model expects -- confirm if images are
    # ever loaded another way.
    transformer.set_transpose('data', (2, 0, 1))

    if mean_file:
        # Read the binaryproto mean blob and register its per-channel mean.
        with open(mean_file, 'rb') as blob_file:
            blob = caffe_pb2.BlobProto()
            blob.MergeFromString(blob_file.read())
            if blob.HasField('shape'):
                blob_dims = blob.shape
                assert len(blob_dims) == 4, 'Shape should have 4 dimensions - shape is "%s"' % blob.shape
            elif blob.HasField('num') and blob.HasField('channels') and \
                    blob.HasField('height') and blob.HasField('width'):
                blob_dims = (blob.num, blob.channels, blob.height, blob.width)
            else:
                raise ValueError('blob does not provide shape or 4d dimensions')
            # Collapse the (C, H, W) mean image to one value per channel.
            channel_mean = np.reshape(blob.data, blob_dims[1:]).mean(1).mean(1)
            transformer.set_mean('data', channel_mean)

    return transformer


def load_image(path, height, width, mode='RGB',):
    """
    Load an image from disk, crop a fixed column band, and resize.

    Returns an np.ndarray of shape (height, width, channels) in BGR
    channel order (cv2 convention), which is why no channel swap is set
    on the transformer.

    Arguments:
    path -- path to an image on disk
    width -- resize dimension (columns)
    height -- resize dimension (rows)

    Keyword arguments:
    mode -- kept for interface compatibility; cv2 always decodes color
        images as BGR regardless of this value
    """
    img_arr = cv2.imread(path)  # NOTE: cv2 loads BGR, not RGB
    if img_arr is None:
        # Fail loudly instead of crashing later on a NoneType slice.
        raise IOError('could not read image: %s' % path)
    # Keep columns 80:560 -- assumes a 640px-wide source (the competition
    # images are 640x480), i.e. the central 480px band.
    img_arr = img_arr[:, 80:560]
    # BUGFIX: cv2.resize takes dsize as (width, height); the original passed
    # (height, width), which only worked because both were 224.
    img_arr = cv2.resize(img_arr, (width, height), interpolation=cv2.INTER_AREA)
    return img_arr

def load_mean_sub_image(path, height, width, mode='RGB'):
    """
    Read an already mean-subtracted image from disk as-is (no crop or
    resize); height, width and mode exist only to mirror load_image().
    """
    return cv2.imread(path)


def forward_pass(images, net, transformer, batch_size=None):
    """
    Run images through the net in batches and collect class probabilities.

    Returns scores for each image as an np.ndarray (nImages x nClasses).

    Arguments:
    images -- a list of np.ndarrays (HxW grayscale or HxWxC color)
    net -- a caffe.Net
    transformer -- a caffe.io.Transformer

    Keyword arguments:
    batch_size -- how many images can be processed at once
        (a high value may result in out-of-memory errors)
    """
    if batch_size is None:
        batch_size = 1

    # Give grayscale images an explicit channel axis so preprocess() works.
    caffe_images = []
    for image in images:
        if image.ndim == 2:
            caffe_images.append(image[:, :, np.newaxis])
        else:
            caffe_images.append(image)

    dims = transformer.inputs['data'][1:]

    scores = None
    # Start the timer once so progress lines report cumulative elapsed time
    # (the original reset it per batch, making the printed figure just the
    # last batch's duration).
    start = time.time()
    for chunk in [caffe_images[x:x + batch_size] for x in range(0, len(caffe_images), batch_size)]:
        new_shape = (len(chunk),) + tuple(dims)
        if net.blobs['data'].data.shape != new_shape:
            net.blobs['data'].reshape(*new_shape)
        for index, image in enumerate(chunk):
            net.blobs['data'].data[index] = transformer.preprocess('data', image)
        output = net.forward()['prob']  # prob3 is for googlenet, prob is for resnet50

        if scores is None:
            # BUGFIX: copy the output. net.forward() returns a view of the
            # net's internal output blob, which the next forward() overwrites;
            # keeping a bare reference corrupted the first batch's scores.
            scores = np.copy(output)
        else:
            scores = np.vstack((scores, output))
        if len(scores) % 10000 == 0:
            print('Processed %s/%s images in %f seconds ...' % (len(scores), len(caffe_images), (time.time() - start)))

    return scores


def read_labels(labels_file):
    """
    Return the non-empty, stripped lines of labels_file as a list.

    Arguments:
    labels_file -- path to a .txt file; any falsy value yields None
        after printing a warning
    """
    if not labels_file:
        print('WARNING: No labels file provided. Results will be difficult to interpret.')
        return None

    with open(labels_file) as infile:
        labels = [line.strip() for line in infile if line.strip()]
    assert len(labels), 'No labels found'
    return labels


def classify(caffemodel, deploy_file, image_files,
             mean_file=None, labels_file=None, batch_size=None, use_gpu=True,
             info_string='fine_tune_googlenet_multistep_solver', DEBUG=True):
    """
    Classify every .jpg under a directory with a Caffe model, print coarse
    confidence statistics, and (when DEBUG is False) write a submission CSV.

    Arguments:
    caffemodel -- path to a .caffemodel
    deploy_file -- path to a .prototxt
    image_files -- directory containing .jpg images (globbed below),
        despite the name suggesting a list of paths

    Keyword arguments:
    mean_file -- path to a .binaryproto mean file, or None to fall back to
        the hard-coded 224x224 geometry (see NOTE on the rebinding below)
    labels_file -- path to a .txt labels file (optional)
    batch_size -- forwarded to forward_pass()
    use_gpu -- if True, run inference on the GPU
    info_string -- tag used in the submission file name
    DEBUG -- True: load images from disk and skip the submission;
        False: reuse the cached image dump and write a submission
    """
    # Load the model and images
    net = get_net(caffemodel, deploy_file, use_gpu)
    transformer = get_transformer(deploy_file, mean_file)
    if mean_file is not None:
        # Input geometry comes from the deploy prototxt via the transformer.
        _, channels, height, width = transformer.inputs['data']
        mean_file=mean_file  # no-op self-assignment, kept as-is
    else:
        channels=3
        height=224
        width=224
        # NOTE(review): mean_file is rebound to False here, so the
        # "if mean_file is None" branch further below can never execute --
        # the manual mean subtraction appears to be dead code.
        mean_file=False
    if channels == 3:
        mode = 'RGB'
    elif channels == 1:
        mode = 'L'
    else:
        raise ValueError('Invalid number for channels: %s' % channels)
    # img_path = os.path.join(data_root, 'imgs', 'test', '*.jpg')
    img_files = glob.glob(image_files + '/*.jpg')
    print 'loading images...'
    #cache = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/cache'
    cache=data_root+'/cache'
    # NOTE(review): 'mean_sub_sample' is tested for exact membership in the
    # list of full image paths, so both checks below are effectively always
    # False; a substring test on image_files was probably intended.
    if DEBUG:
        if 'mean_sub_sample' in img_files:
            images = [load_mean_sub_image(image_file, height, width, mode) for image_file in img_files]
        else:
            images = [load_image(image_file, height, width, mode) for image_file in img_files]
        images_arr= np.asarray(images)
        # print 'loading image 10 crop'
        # images = [load_image_10crop(image_file, height, width, mode) for image_file in img_files]
    # np.save(cache+'/test_images.npy',images)
    else:
        if 'mean_sub_sample' in img_files:
            images = [load_mean_sub_image(image_file, height, width, mode) for image_file in img_files]
        else:
            # Non-debug runs reuse a cached numpy dump of the full test set.
            images = np.load(cache + '/test_images_224.npy')
            #images = [load_image(image_file, height, width, mode) for image_file in img_files]
            #np.save(cache + '/test_images_224.npy', images)
        images_arr=np.asarray(images)
    labels = read_labels(labels_file)
    if mean_file is None:
        # Manual ImageNet BGR mean subtraction. See NOTE above: unreachable,
        # because mean_file was rebound to False when no mean file was given.
        # NOTE(review): were it reachable, the (3, 1, 1) mean would broadcast
        # against HxWxC images along the wrong axis -- confirm before reuse.
        mean_vec = np.array([103.939, 116.779, 123.68], dtype=np.float32)
        mean_vec_rgb = np.array([123.68, 116.779, 103.939], dtype=np.float32)
        reshaped_mean_vec = mean_vec.reshape(3, 1, 1)
        images = images_arr - reshaped_mean_vec
        print images.shape, mean_vec.shape, reshaped_mean_vec.shape
    # Submission rows are keyed by the bare image file name.
    test_id = []
    for fl in img_files:
        flbase = os.path.basename(fl)
        test_id.append(flbase)
        # _, test_id = load_test(224, 224, 3)

    # Classify the image
    print 'classify...'
    scores = forward_pass(images, net, transformer, batch_size=batch_size)
    # print 'scores shape', scores.shape
    # print 'scores',scores[0:5]
    location = np.argmax(scores, axis=1)  # top-1 class index per image
    # print 'calss', location
    # Count how confident the top-1 predictions are, at three thresholds.
    cnt1 = 0
    cnt2=0
    cnt3=0
    for i in range(len(scores)):
        # print 'prob', scores[i, location[i]]
        if scores[i, location[i]] > 0.9:
            cnt1 += 1
        if scores[i, location[i]] >= 0.99:
            cnt2 += 1
        if scores[i, location[i]] == 1:
            cnt3 += 1
    # NOTE: integer division under Python 2 -- percentages are truncated.
    print 'prediction > 0.9 :',cnt1 * 100 / len(scores), '%'
    print 'prediction >=0.99:',cnt2 * 100 / len(scores), '%'
    print 'prediction = 1   :',cnt3 * 100 / len(scores), '%'
    # print 'scores', scores
    # for i in range(len(img_files)):
    #   print 'predicted clss:', scores[i].argmax()

    # test_res = merge_several_folds_mean(scores, nfolds=1)
    # print np.asarray(test_res).shape
    if DEBUG is not True:
        print 'creating submission...'
        create_submission(scores, test_id, info_string)
    """
    ### Process the results

    indices = (-scores).argsort()[:, :5]  # take top 5 results
    classifications = []
    for image_index, index_list in enumerate(indices):
        result = []
        for i in index_list:
            # 'i' is a category in labels and also an index into scores
            if labels is None:
                label = 'Class #%s' % i
            else:
                label = labels[i]
            result.append((label, round(100.0 * scores[image_index, i], 4)))
        classifications.append(result)

    for index, classification in enumerate(classifications):
        print '{:-^80}'.format(' Prediction for %s ' % image_files[index])
        for label, confidence in classification:
            print '{:9.4%} - "{}"'.format(confidence / 100.0, label)
        print
    """


# ---------------------------------------------------------------------------
# Script configuration: data/model paths, snapshot choices and submission
# name tags. The triple-quoted blocks below are disabled configurations for
# other architectures (GoogLeNet / ResNet-50 / VGG16), kept for reference.
# ---------------------------------------------------------------------------
snapshot_root = data_root+'/fine-tune-caffe/snapshot/'
TEST_IMG_ROOT = data_root+'/imgs/test'
# Small sample directory used for quick DEBUG runs.
sample_img = data_root+'/imgs/c4_p072_sample'
# NOTE(review): 'mean_sub_sampe' looks like a typo of 'mean_sub_sample';
# the variable is never read in this file.
mean_sub_sampe=data_root+'/imgs/test_mean_substracted_224/bgr'
'''
# googlenet
googlenet_root = '/home/lab/xryu/code/DistractedDriverDetection/kaggle_ddd_finetune_caffe/fine_tune_googlenet/'
model_def = googlenet_root + '25vs1/googlenet_test.prototxt'
model_def_no_dropout = googlenet_root + '/8foldcvgooglenet_test_no_dropout.prototxt'
model_def_10crop = googlenet_root + 'googlenet_test_10crop.prototxt'
# snapshot_root = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/fine-tune-caffe/snapshot/'
# model_name='finetune_googlenet_auto/_iter_7000.caffemodel'
# model_name='8foldcv/googlenet_quick_0/_iter_1400.caffemodel'
# model_name = '8foldcv/googlenet_quick_2/_iter_7000.caffemodel'  # test top1 acc=1 loss 0.17 0.12 0.10
# model_name='8foldcv/googlenet_quick_3/_iter_2800.caffemodel'# test top1 acc=0.9375 0.5 0.4
# model_name='10foldcv/googlenet_quick_1/_iter_6000.caffemodel'# test top1 acc=0.95 loss=0.23
# model_name='8foldcv/googlenet_step_0/_iter_5600.caffemodel'
# 3 eopch
#model_name1 = 'multicrop/googlenet_step/_iter_8200.caffemodel' #score:0.78
#model_name2 = 'multicrop/googlenet_step/_iter_16400.caffemodel'#score:0.6
#model_name3 = 'multicrop/googlenet_step/_iter_24600.caffemodel'#score:0.5
#6epoch
model_name1 = 'multicrop/googlenet_step_6epoch/_iter_32800.caffemodel'# >0.9 =67%
model_name2 = 'multicrop/googlenet_step_6epoch/_iter_41000.caffemodel'
model_name3 = 'multicrop/googlenet_step_6epoch/_iter_49200.caffemodel'

model_weights = snapshot_root + model_name1
model_weights2 = snapshot_root + model_name2
model_weights3 = snapshot_root + model_name3

# model_weights1 = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/fine-tune-caffe/snapshot/' \
#                 'finetune_googlenet_multistep/_iter_12600.caffemodel'
# TEST_IMG_ROOT = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/imgs/test'
# sample_img = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/imgs/c4_p072_sample'
imagenet_mean_path = '/media/lab/labdisk/home/lab1314/xryu/pretrined_model/caffe/googlenet/imagenet_mean.binaryproto'
'''
'''
# resnet50
resnet50_root='/home/lab/xryu/code/DistractedDriverDetection/kaggle_ddd_finetune_caffe/fine_tune_resnet50/'
model_def=resnet50_root+'ResNet-50-deploy.prototxt'
imagenet_mean_path = '/media/lab/labdisk/home/lab1314/xryu/pretrined_model/caffe/ResNet/ResNet_mean.binaryproto'
model_name1='finetune_resnet50_quick_10vs16/_iter_2500.caffemodel'
model_name2='finetune_resnet50_quick_10vs16/_iter_5000.caffemodel'
model_name3='finetune_resnet50_quick_10vs16/_iter_7500.caffemodel'
model_weights=snapshot_root+model_name1
model_weights2 = snapshot_root + model_name2
model_weights3 = snapshot_root + model_name3
'''
'''
# vgg16
model_def = '/home/lab/xryu/code/DistractedDriverDetection/kaggle_ddd_finetune_caffe/fine_tune_vgg16/VGG_ILSVRC_16_layers_deploy.prototxt'
model_weights = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/fine-tune-caffe/finetune_vgg16_multistep/' \
                '_iter_10400.caffemodel'
TEST_IMG_ROOT = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/imgs/test'
sample_img = '/media/lab/labdisk/home/lab1314/xryu/data/DistractedDriverDetection/imgs/c4_p072_sample'
imagenet_mean_path = '/media/lab/labdisk/home/lab1314/xryu/pretrined_model/caffe/vgg/VGG_mean.binaryproto'
'''

# resnet152 -- the configuration actually in effect.
resnet152_root='/home/dell/Rui/kaggle_ddd_finetune_caffe/fine_tune_resnet152'
model_def=resnet152_root+'/ResNet-152-deploy.prototxt'
imagenet_mean_path = '/media/dell/delldisk/dell/Rui/pretrined_model/caffe/ResNet/ResNet_mean.binaryproto'


# Three snapshots of the same training run, evaluated one after another.
model_name1 = 'multicrop/resnet152_step/_iter_16400.caffemodel'# >0.9 debug70% full predict68%
model_name2 = 'multicrop/resnet152_step/_iter_32800.caffemodel'# >0.9 debug87% full predict81%
model_name3 = 'multicrop/resnet152_step/_iter_49200.caffemodel'# >0.9 debug87% full predict82%

model_weights = snapshot_root + model_name1
model_weights2 = snapshot_root + model_name2
model_weights3 = snapshot_root + model_name3

# Human-readable tag from the last three path components, used to name the
# submission file, e.g. 'multicrop_resnet152_step_iter_16400'.
info_string = model_weights.split('/')[-3] + '_' + model_weights.split('/')[-2] + \
              model_weights.split('/')[-1].split('.')[0]
info_string2 = model_weights2.split('/')[-3] + '_' + model_weights2.split('/')[-2] + \
              model_weights2.split('/')[-1].split('.')[0]
info_string3 = model_weights3.split('/')[-3] + '_' + model_weights3.split('/')[-2] + \
              model_weights3.split('/')[-1].split('.')[0]
# print info_string
# DEBUG=True classifies the small sample directory only; False runs the
# full test set and writes submission files.
DEBUG = True
if DEBUG:
    image_files = sample_img
else:
    image_files = TEST_IMG_ROOT

# Evaluate each of the three snapshots with identical settings.
for snapshot_weights, snapshot_info in ((model_weights, info_string),
                                        (model_weights2, info_string2),
                                        (model_weights3, info_string3)):
    classify(caffemodel=snapshot_weights,
             deploy_file=model_def,
             image_files=image_files,
             mean_file=imagenet_mean_path,
             labels_file=None,
             batch_size=16,
             use_gpu=True,
             info_string=snapshot_info,
             DEBUG=DEBUG
             )

"""
if __name__ == '__main__':
    script_start_time = time.time()

    parser = argparse.ArgumentParser(description='Classification example - DIGITS')

    ### Positional arguments

    parser.add_argument('caffemodel',   default=model_weights,help='Path to a .caffemodel')
    parser.add_argument('deploy_file',  default=model_def,help='Path to the deploy file')
    parser.add_argument('image_file',default=TEST_IMG_ROOT,
                        nargs='+',
                        help='Path[s] to an image')

    ### Optional arguments

    parser.add_argument('-m', '--mean',default=imagenet_mean_path,
            help='Path to a mean file (*.npy)')
    parser.add_argument('-l', '--labels',
            help='Path to a labels file')
    parser.add_argument('--batch-size',default=16,
                        type=int)
    parser.add_argument('--nogpu',
            action='store_true',
            help="Don't use the GPU")

    args = vars(parser.parse_args())

    classify(args['caffemodel'], args['deploy_file'], args['image_file'],
            args['mean'], args['labels'], args['batch_size'], not args['nogpu'])

    print 'Script took %f seconds.' % (time.time() - script_start_time,)

"""
